import os, time, csv, itertools, psutil
import subprocess
import collections

from threading import Thread
from .common import *
from .collect import *


class CtrMainCollect(object):
    """Drive periodic metric collection for every running container.

    One ``CtrCollect`` is created per container ID discovered at construction
    time; each collection round runs all of them concurrently in threads and
    appends one row to each container's CSV output file.
    """

    def __init__(
        self,
        interval,
        output,
        collects,
        driver,
        cgroupPrefix="/sys/fs/cgroup",
        docker=True,
        containerd=False,
    ):
        """
        Args:
            interval: seconds slept after starting each collection round.
            output: directory receiving one ``<ctrID>-container.data`` file
                per container.
            collects: requested collection types; types not listed by
                ``supHelper.getCtrSupports()`` are dropped.
            driver: forwarded unchanged to each per-container collector
                (presumably the cgroup driver name — confirm with cgroupHelper).
            cgroupPrefix: cgroup filesystem prefix handed to the collectors.
            docker: list containers with ``docker ps`` (checked first).
            containerd: list containers with ``ctr containers ls``.
        """
        self.cgroupType = "v2" if cgroupHelper.supportCgv2() else "v1"
        self.interval = interval
        self.output = output
        self.useDocker = docker
        self.useContainerd = containerd
        self.driver = driver
        self.cgroupPrefix = cgroupPrefix

        self.supportTypes = supHelper.getCtrSupports()
        self.collects = self.filter_collects(collects)

        self.ctrIDs = self.get_ctr_IDs()
        self.ctrCollects = self.init_ctr_collects()
        self.prepare_output_file()

    def prepare_output_file(self):
        """Create (or truncate) each container's output file and write the header row."""
        if not self.ctrIDs:
            print("no containers found")
            return

        # TODO: ctrs may change.....
        columns = self.prepare_columns()

        for c in self.ctrCollects:
            # Mode "w" creates the file when missing and truncates it otherwise,
            # so no separate creation step is needed (and mknod is avoided,
            # since macOS doesn't support it). newline="" is what the csv
            # module expects for files it writes.
            with open(c.outputFile, "w", newline="") as csvfile:
                csvWriter = csv.writer(csvfile, delimiter=",")
                csvWriter.writerow(columns)

    def get_ctr_IDs(self):
        """Return the full IDs of running containers, excluding "pause" ones.

        Returns an empty list when no container runtime is enabled, when the
        listing command cannot be started, or when nothing is running.
        """
        # TODO: match k8s namespace if needed
        if self.useDocker:
            cmd = "docker ps --no-trunc | grep -v pause | awk '{print $1}' | tail -n +2"
        elif self.useContainerd:
            cmd = "ctr containers ls | grep -v pause | awk '{print $1}' | tail -n +2"
        else:
            print("neither docker nor containerd is enabled, no containers to collect")
            return []

        try:
            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True)
            out, _ = proc.communicate()
        except (OSError, subprocess.SubprocessError) as e:
            # Popen/communicate raise OSError (not CalledProcessError) when the
            # shell itself cannot be started.
            print(
                "%s : execute %s failed, %s"
                % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), cmd, e)
            )
            return []

        # Drop blank lines: with no containers the pipeline prints nothing and
        # splitting "" would otherwise yield a bogus empty container ID.
        return [
            line.strip() for line in out.decode("UTF-8").splitlines() if line.strip()
        ]

    def prepare_columns(self):
        """Return the CSV header: "timestamp" followed by each collector's columns."""
        columns = ["timestamp"]
        classes = supHelper.getCtrClasses()
        for cType in self.collects:
            if cType == "io":
                cols = globals()[classes[cType]].get_columns(self.cgroupType)
            else:
                cols = globals()[classes[cType]].get_columns()
            columns.extend(cols)
        return columns

    def collect(self):
        """Collect forever, one thread per container per round.

        A round lasts at least ``self.interval`` seconds, and longer if any
        container's collection is still running when the sleep ends.
        """
        while True:
            threads = []
            for ctr in self.ctrCollects:
                # run thread for each container
                thread = Thread(target=CtrCollect.collect, args=(ctr,))
                thread.start()
                threads.append(thread)

            time.sleep(self.interval)

            # Wait for every container before the next round so a slow
            # collection never runs concurrently with its own next round
            # (both would append to the same output file).
            for thread in threads:
                thread.join()

    def filter_collects(self, collects):
        """Keep only the requested types that this host supports, in request order."""
        return [cType for cType in collects if cType in self.supportTypes]

    def init_ctr_collects(self):
        """Build one ``CtrCollect`` per discovered container ID."""
        ctrCollects = []
        for ctrID in self.ctrIDs:
            cc = CtrCollect(
                self.cgroupType,
                self.driver,
                self.output,
                self.cgroupPrefix,
                ctrID,
                self.collects,
            )
            cc.add_metadata()
            ctrCollects.append(cc)
        return ctrCollects


class CtrCollect(BaseCollect):
    """Collect every requested metric type for a single container.

    Each instance owns one collector object per metric type and writes its
    rows to ``<output>/<ctrID>-container.data``.
    """

    def __init__(self, v, driver, output, cgroupPath, ctrID, collects):
        """
        Args:
            v: cgroup version string ("v1"/"v2") handed to each collector.
            driver: forwarded unchanged to each collector.
            output: directory the output file lives in.
            cgroupPath: cgroup prefix handed to each collector.
            ctrID: container ID; also names the output file.
            collects: metric type names to build collectors for.
        """
        self.v = v
        self.driver = driver
        self.cgroupPath = cgroupPath
        self.ctrID = ctrID
        # init_collects reads the four attributes above, so they come first.
        self.collects = self.init_collects(collects)
        self.outputFile = "/".join((output, ctrID + "-container.data"))

    def init_collects(self, cTypes):
        """Instantiate the collector class registered for each type in ``cTypes``."""
        registry = supHelper.getCtrClasses()
        namespace = globals()
        instances = []
        for cType in cTypes:
            collectorCls = namespace[registry[cType]]
            # the "rdt" collector is constructed with an extra leading None
            leading = (None,) if cType == "rdt" else ()
            instances.append(
                collectorCls(*leading, self.v, self.driver, self.ctrID, self.cgroupPath)
            )
        return instances

    def add_metadata(self):
        """Placeholder: write container metadata at the top of the output file.

        Planned fields: container id, machine id, image name, pod name,
        cpu_shares, cpu_request, cpu_limit, storage. Currently a no-op.
        """
        return None

    def collect(self):
        """Take one sample from every collector and write it as a single row."""
        row = [value for collector in self.collects for value in collector.collect()]
        self.write(row, self.outputFile)


class CtrBaseCollect(object):
    """Shared state and cgroup path lookup for container-level collectors."""

    def __init__(self, v, driver, ctrID, cgroupPrefix="/sys/fs/cgroup"):
        """Store the parameters used to locate this collector's cgroup.

        For non-container collection, pass ``None`` for ``v`` and ``ctrID``.
        """
        self.v, self.driver = v, driver
        self.ctrID, self.cgroupPrefix = ctrID, cgroupPrefix

    def get_path(self, cgtype):
        """Return the cgroup path for ``cgtype``; callers join stat file names onto it."""
        lookup = {
            "v": self.v,
            "driver": self.driver,
            "ctrID": self.ctrID,
            "cgroupPrefix": self.cgroupPrefix,
        }
        return cgroupHelper.get_path(cgtype, **lookup)


class CtrCpuCollect(CtrBaseCollect):
    """CPU usage and CFS throttling counters for one container's cgroup.

    Each sample spans a 1-second window. All three values are deltas over
    that window, formatted as strings with three decimals.
    """

    def __init__(self, v, driver, ctrID, cgroupPath="/sys/fs/cgroup"):
        """for non-container collect, set v and ctrID as none"""
        super().__init__(v, driver, ctrID, cgroupPath)

    @classmethod
    def get_columns(cls):
        """Return the column names matching the values returned by :meth:`collect`."""
        cols = ["cpu_usage", "nr_throttled", "throttled_usec"]
        return cols

    @staticmethod
    def _read_flat_keyed(path):
        """Parse a "key value"-per-line cgroup file into a dict of ints."""
        stats = {}
        with open(path) as f:
            for line in f:
                fields = line.split()
                if len(fields) >= 2:
                    stats[fields[0]] = int(fields[1])
        return stats

    def collect_v2(self):
        """cgroup v2: usage_usec, nr_throttled and throttled_usec from cpu.stat.

        cpu_usage is CPU-seconds consumed per wall-clock second of the window.
        """
        cpuStatPath = "/".join([self.get_path("cpu"), "cpu.stat"])

        interval = 1
        before = self._read_flat_keyed(cpuStatPath)
        time.sleep(interval)
        after = self._read_flat_keyed(cpuStatPath)

        cpuUsage = (after["usage_usec"] - before["usage_usec"]) * 1e-6 / interval
        nr_throttled = after["nr_throttled"] - before["nr_throttled"]
        throttled_usec = after["throttled_usec"] - before["throttled_usec"]

        return [
            "{:.3f}".format(item) for item in [cpuUsage, nr_throttled, throttled_usec]
        ]

    def collect_v1(self):
        """cgroup v1: cpuacct.usage (nanoseconds) plus throttling counters from cpu.stat.

        v1 cpu.stat reports ``throttled_time`` in nanoseconds instead of
        ``throttled_usec``; it is converted to microseconds so both cgroup
        versions fill the ``throttled_usec`` column in the same unit.
        """
        path = self.get_path("cpu")
        cpuUsagePath = "/".join([path, "cpuacct.usage"])
        cpuStatPath = "/".join([path, "cpu.stat"])

        def read():
            # Values must be ints: the original string values could not be subtracted.
            with open(cpuUsagePath) as f:
                usage = int(f.readline())
            stat = self._read_flat_keyed(cpuStatPath)
            if "throttled_usec" in stat:
                throttled_usec = stat["throttled_usec"]
            else:
                throttled_usec = stat["throttled_time"] / 1000.0
            return usage, stat["nr_throttled"], throttled_usec

        interval = 1
        cstart, nr_throttled_before, throttled_usec_before = read()
        time.sleep(interval)
        cstop, nr_throttled_after, throttled_usec_after = read()

        cpuUsage = (cstop - cstart) * 1e-9 / interval
        return [
            "{:.3f}".format(item)
            for item in [
                cpuUsage,
                nr_throttled_after - nr_throttled_before,
                throttled_usec_after - throttled_usec_before,
            ]
        ]

    def collect(self):
        """Dispatch on cgroup version; returns None for an unknown version."""
        if self.v == "v1":
            return self.collect_v1()
        if self.v == "v2":
            return self.collect_v2()


class CtrMemCollect(CtrBaseCollect):
    """Memory usage of one container's cgroup.

    TODO: distinguish cache?
    """

    def __init__(self, v, driver, ctrID, cgroupPath="/sys/fs/cgroup"):
        super().__init__(v, driver, ctrID, cgroupPath)

    @classmethod
    def get_columns(cls):
        """Return the column names matching the values returned by :meth:`collect`."""
        return ["mem.usage", "mem.rss"]

    @staticmethod
    def _read_int(path):
        """Return the integer on the first line of a single-value cgroup file."""
        with open(path, "r") as f:
            return int(f.readline())

    @staticmethod
    def _read_stat(path):
        """Parse a "key value"-per-line file such as memory.stat into a dict of strings."""
        stats = {}
        with open(path, "r") as f:
            for line in f:
                fields = line.split()
                if len(fields) >= 2:
                    stats[fields[0]] = fields[1]
        return stats

    def collect_v1(self):
        """cgroup v1: mem.usage = memory.usage_in_bytes, mem.rss = memory.stat[rss]."""
        path = self.get_path("memory")
        usage = self._read_int("/".join([path, "memory.usage_in_bytes"]))
        stat = self._read_stat("/".join([path, "memory.stat"]))
        return [usage, int(stat["rss"])]

    def collect_v2(self):
        """cgroup v2: mem.usage = memory.current, mem.rss = memory.stat[anon] + swap."""
        path = self.get_path("memory")
        usage = self._read_int("/".join([path, "memory.current"]))

        try:
            swap = self._read_int("/".join([path, "memory.swap.current"]))
        except FileNotFoundError:
            # memory.swap.current may be absent (e.g. kernels without swap
            # accounting); count that as no swap instead of failing the sample.
            swap = 0

        stat = self._read_stat("/".join([path, "memory.stat"]))
        rss = int(stat["anon"]) + swap

        return [usage, rss]

    def collect(self):
        """Dispatch on cgroup version; returns None for an unknown version."""
        if self.v == "v1":
            return self.collect_v1()
        if self.v == "v2":
            return self.collect_v2()


class CtrIoCollect(CtrBaseCollect):
    """Per-device block I/O bytes for each container.

    Attributes:
        devices: block device names listed under /sys/block when the columns
            were built, e.g. ["sda", "dm-0"]; shared by all instances and
            defining the order of the returned values.
        device_ids: maps each name in ``devices`` to the "major:minor" string
            read from /sys/block/<dev>/dev, e.g. {"sda": "8:0"}. The cgroup
            I/O stat files are keyed by "major:minor", not by name.
    """

    # Both are filled by get_columns() and shared across instances.
    devices = []
    device_ids = {}

    def __init__(self, v, driver, ctrID, cgroupPath="/sys/fs/cgroup"):
        super().__init__(v, driver, ctrID, cgroupPath)

    @classmethod
    def get_columns(cls, v=None):
        """Return "<device>-read"/"<device>-write" columns for every block device.

        Also records ``devices`` and ``device_ids``, which match_device uses
        to line values up with these columns. ``v`` is accepted but unused.
        """
        # hard to predict which devices a container touches, so emit every device
        sys_block = "/sys/block"
        CtrIoCollect.devices = os.listdir(sys_block)
        CtrIoCollect.device_ids = {
            device: CtrIoCollect._read_dev_id(os.path.join(sys_block, device))
            for device in CtrIoCollect.devices
        }

        columns = []
        for device in CtrIoCollect.devices:
            columns.append(device + "-read")
            columns.append(device + "-write")
        return columns

    @staticmethod
    def _read_dev_id(device_dir):
        """Return the "major:minor" string from ``<device_dir>/dev``, or None if unreadable."""
        try:
            with open(os.path.join(device_dir, "dev"), "r") as f:
                return f.readline().strip()
        except OSError:
            return None

    def collect_v1(self):
        """cgroupv1 version of ctr io
        An example blkio.throttle.io_service_bytes output follows:

        8:0 Read 1155072
        8:0 Write 0
        ...
        253:0 Read 1155072
        253:0 Write 0
        ...
        Total 2310144
        """
        path = self.get_path("blkio")
        io_serv_path = "/".join([path, "blkio.throttle.io_service_bytes"])

        def read():
            io = collections.defaultdict(dict)
            with open(io_serv_path, "r") as f:
                for line in f:
                    fields = line.split()
                    # Skip the trailing "Total <n>" summary, which has only two fields.
                    if len(fields) != 3:
                        continue
                    device, op, value = fields
                    io[device][op] = value
            return io

        before_io = read()
        time.sleep(1)
        after_io = read()

        return self.match_device(before_io, after_io)

    def collect_v2(self):
        """cgroupv2 version of ctr io
        An example io.stat output follows:

          8:16 rbytes=1459200 wbytes=314773504 rios=192 wios=353
          8:0 rbytes=90430464 wbytes=299008000 rios=8950 wios=125
        """
        path = self.get_path("blkio")
        io_stat_path = "/".join([path, "io.stat"])

        def read():
            io = {}
            with open(io_stat_path, "r") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue
                    # Values stay strings here: only rbytes/wbytes are converted
                    # later, so any non-integer extra keys cannot break parsing.
                    counters = {}
                    for elem in fields[1:]:
                        key, _, value = elem.partition("=")
                        counters[key] = value
                    io[fields[0]] = counters
            return io

        before_io = read()
        time.sleep(1)
        after_io = read()

        return self.match_device(before_io, after_io)

    def match_device(self, before_io, after_io):
        """Return [read, write, read, write, ...] byte deltas in ``devices`` order.

        A device present only in ``after_io`` had no recorded I/O before the
        window, so its full counter is the delta. A device missing from
        ``after_io`` yields -1 for both values.
        """
        if self.v == "v2":
            read_key, write_key = "rbytes", "wbytes"
        else:
            read_key, write_key = "Read", "Write"

        data = []
        for device in CtrIoCollect.devices:
            # Stat files are keyed by "major:minor"; fall back to the name itself.
            dev_id = CtrIoCollect.device_ids.get(device) or device
            if dev_id not in after_io:
                data.extend([-1, -1])
                continue
            after = after_io[dev_id]
            before = before_io.get(dev_id, {})
            data.extend(
                [
                    int(after.get(read_key, 0)) - int(before.get(read_key, 0)),
                    int(after.get(write_key, 0)) - int(before.get(write_key, 0)),
                ]
            )
        return data

    def collect(self):
        """Dispatch on cgroup version; returns None for an unknown version."""
        if self.v == "v1":
            return self.collect_v1()
        if self.v == "v2":
            return self.collect_v2()


class CtrNetworkCollect(CtrBaseCollect):
    """not used
    So far we don't collect ctr network info.

    To get net info of ctr, inspect the pid of ctr process,
    and add pid to node section in config.toml
    """

    def __init__(self, v, driver, ctrID, cgroupPath="/sys/fs/cgroup"):
        super().__init__(v, driver, ctrID, cgroupPath)

    @classmethod
    def get_columns(cls, v=None):
        """Return no columns. ``v`` is accepted but ignored.

        ``v`` is optional because CtrMainCollect.prepare_columns calls
        ``get_columns()`` with no arguments for every type other than "io".
        """
        return []

    def collect(self):
        """Return an empty sample; network metrics are not collected here."""
        return []
